{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize MLeap libraries before Scikit/Pandas\n",
    "import mleap.sklearn.preprocessing.data\n",
    "import mleap.sklearn.pipeline\n",
    "from mleap.sklearn.ensemble import forest\n",
    "from mleap.sklearn.preprocessing.data import FeatureExtractor\n",
    "\n",
    "# Import pandas, NumPy and the Scikit-learn pipeline/estimator classes\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the iris data set; the four measurement columns are the model inputs\n",
    "data = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')\n",
    "input_features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']\n",
    "\n",
    "# The output vector name is used only for serialization purposes\n",
    "output_vector_name = 'extracted_features'\n",
    "output_features = list(input_features)\n",
    "\n",
    "feature_extractor_tf = FeatureExtractor(\n",
    "    input_scalars=input_features,\n",
    "    output_vector=output_vector_name,\n",
    "    output_vector_items=output_features\n",
    ")\n",
    "\n",
    "# Small, seeded forest (random_state=0) so repeated runs give the same model\n",
    "classification_tf = RandomForestClassifier(\n",
    "    bootstrap=True,\n",
    "    class_weight=None,\n",
    "    criterion='gini',\n",
    "    max_depth=2,\n",
    "    max_features='auto',\n",
    "    max_leaf_nodes=None,\n",
    "    min_impurity_decrease=0.0,\n",
    "    min_impurity_split=None,\n",
    "    min_samples_leaf=1,\n",
    "    min_samples_split=2,\n",
    "    min_weight_fraction_leaf=0.0,\n",
    "    n_estimators=10,\n",
    "    n_jobs=1,\n",
    "    oob_score=False,\n",
    "    random_state=0,\n",
    "    verbose=0,\n",
    "    warm_start=False\n",
    ")\n",
    "\n",
    "classification_tf.mlinit(\n",
    "    input_features='features',\n",
    "    prediction_column='species',\n",
    "    feature_names=\"features\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('vector_assembler_b17f1328-c767-11e8-85f3-88e9fe56e175', FeatureExtractor(input_scalars=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'],\n",
       "         input_vectors=None, output_vector='extracted_features',\n",
       "         output_vector_items=['sepal_length', 'sepal_width', 'petal_length'...estimators=10, n_jobs=1,\n",
       "            oob_score=False, random_state=0, verbose=0, warm_start=False))])"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chain the feature extractor and the classifier, each keyed by its `name`\n",
    "pipeline_steps = [\n",
    "    (feature_extractor_tf.name, feature_extractor_tf),\n",
    "    (classification_tf.name, classification_tf)\n",
    "]\n",
    "rf_pipeline = Pipeline(pipeline_steps)\n",
    "rf_pipeline.mlinit()\n",
    "\n",
    "# Fit on the feature columns against the species labels\n",
    "features_df = data[input_features]\n",
    "labels = data['species']\n",
    "rf_pipeline.fit(features_df, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the pipeline out as an MLeap bundle\n",
    "bundle_dir = '/tmp'\n",
    "bundle_name = 'mleap-scikit-rf-pipeline'\n",
    "rf_pipeline.serialize_to_bundle(bundle_dir, bundle_name, init=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['__abstractmethods__', '__class__', '__delattr__', '__dict__', '__doc__', '__format__', '__getattribute__', '__getstate__', '__hash__', '__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_abc_cache', '_abc_negative_cache', '_abc_negative_cache_version', '_abc_registry', '_estimator_type', '_final_estimator', '_fit', '_get_param_names', '_get_params', '_inverse_transform', '_pairwise', '_replace_estimator', '_set_params', '_transform', '_validate_names', '_validate_steps', 'classes_', 'decision_function', 'deserialize_from_bundle', 'fit', 'fit_predict', 'fit_transform', 'get_params', 'inverse_transform', 'memory', 'mlinit', 'name', 'named_steps', 'op', 'predict', 'predict_log_proba', 'predict_proba', 'score', 'serializable', 'serialize_to_bundle', 'set_params', 'steps', 'transform']\n"
     ]
    }
   ],
   "source": [
    "# List the attributes of the pipeline object; `rf_pipeline` is the only\n",
    "# pipeline this notebook defines, so refer to it by that name\n",
    "print(dir(rf_pipeline))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa',\n",
       "       'setosa', 'setosa', 'versicolor', 'versicolor', 'virginica',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'virginica', 'versicolor', 'virginica', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'virginica',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'virginica', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'versicolor', 'versicolor',\n",
       "       'versicolor', 'versicolor', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'versicolor', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica', 'virginica', 'virginica', 'virginica',\n",
       "       'virginica', 'virginica'], dtype=object)"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict from the same feature columns used for fitting, so the\n",
    "# 'species' label column is not passed in as model input\n",
    "rf_pipeline.predict(data[input_features])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    }
   ],
   "source": [
    "# Report how many classes the fitted classifier holds\n",
    "n_classes = classification_tf.n_classes_\n",
    "print(n_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
